{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "cellView": "form",
        "colab": {},
        "colab_type": "code",
        "id": "k2QMetc6kava"
      },
      "outputs": [],
      "source": [
        "#@title ##### License\n",
        "# Copyright 2018 The GraphNets Authors. All Rights Reserved.\n",
        "#\n",
        "# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "#    http://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License.\n",
        "# ============================================================================"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "c5CPvyHM2CnU"
      },
      "source": [
        "# Physical dynamics of a mass-spring system\n",
        "This notebook and the accompanying code demonstrate how to use the Graph Nets library to learn to predict the motion of a set of masses connected by springs.\n",
        "\n",
        "The network is trained to predict the behaviour of chains of five to eight masses connected by springs. The first and last masses of each chain are fixed; the others move under gravity and the spring forces.\n",
        "\n",
        "After training, the network's predictions are compared with the true behaviour of the system. The network's ability to generalise is then tested by using it to predict the behaviour of chains with fewer (4) and more (9) masses than it saw during training."
      ]
    },
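    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As a rough illustration of how such a chain can be encoded as a graph, the plain-Python sketch below builds the sender and receiver index lists for a chain of `n` masses. It follows the same rule as the `base_graph` helper defined later in this notebook. Each spring contributes one directed edge into each of its free endpoints, so the two fixed end masses never receive messages. The function name `chain_edges` is hypothetical and is not part of Graph Nets.\n",
        "\n",
        "```python\n",
        "def chain_edges(n):\n",
        "  # Returns (senders, receivers) for a chain of n masses whose first and\n",
        "  # last masses are fixed; no edge points into a fixed mass.\n",
        "  senders, receivers = [], []\n",
        "  for left in range(n - 1):\n",
        "    right = left + 1\n",
        "    if right \u003c n - 1:  # right mass is free: add edge left to right\n",
        "      senders.append(left)\n",
        "      receivers.append(right)\n",
        "    if left \u003e 0:  # left mass is free: add edge right to left\n",
        "      senders.append(right)\n",
        "      receivers.append(left)\n",
        "  return senders, receivers\n",
        "\n",
        "\n",
        "senders, receivers = chain_edges(5)\n",
        "print(list(zip(senders, receivers)))\n",
        "```\n",
        "\n",
        "For five masses this yields six directed edges, two for each of the three free masses."
      ]
    },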
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "cellView": "form",
        "colab": {},
        "colab_type": "code",
        "id": "Ss54UNGvkz5M"
      },
      "outputs": [],
      "source": [
        "#@title ### Install the Graph Nets library on this Colaboratory runtime  { form-width: \"60%\", run: \"auto\"}\n",
        "#@markdown \u003cbr\u003e1. Connect to a local or hosted Colaboratory runtime by clicking the **Connect** button at the top-right.\u003cbr\u003e2. Choose \"Yes\" below to install the Graph Nets library and its dependencies on the runtime machine. This works with both local and hosted Colaboratory runtimes.\n",
        "\n",
        "install_graph_nets_library = \"No\"  #@param [\"Yes\", \"No\"]\n",
        "\n",
        "if install_graph_nets_library.lower() == \"yes\":\n",
        "  print(\"Installing Graph Nets library and dependencies:\")\n",
        "  print(\"Output message from command:\\n\")\n",
        "  !pip install graph_nets \"dm-sonnet\u003c2\" \"tensorflow_probability\u003c0.9\"\n",
        "else:\n",
        "  print(\"Skipping installation of Graph Nets library\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "7E4elJkXFR4a"
      },
      "source": [
        "### Install dependencies locally\n",
        "\n",
        "If you are running this notebook locally (i.e., not through Colaboratory), you will also need to install a few more dependencies. Run the following on the command line to install the Graph Nets library and the other dependencies:\n",
        "\n",
        "```\n",
        "pip install graph_nets matplotlib scipy \"tensorflow\u003e=1.15,\u003c2\" \"dm-sonnet\u003c2\" \"tensorflow_probability\u003c0.9\"\n",
        "```"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "hyVaNA-bGug3"
      },
      "source": [
        "# Code"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "cellView": "form",
        "colab": {},
        "colab_type": "code",
        "id": "RzKvtoAgRA8e"
      },
      "outputs": [],
      "source": [
        "#@title Imports  { form-width: \"30%\" }\n",
        "\n",
        "# The demo dependencies are not installed with the library, but you can install\n",
        "# them with:\n",
        "#\n",
        "# $ pip install jupyter matplotlib scipy\n",
        "#\n",
        "# Run the demo with:\n",
        "#\n",
        "# $ jupyter notebook \u003cpath\u003e/\u003cto\u003e/\u003cdemos\u003e/physics.ipynb\n",
        "%tensorflow_version 1.x  # For Google Colab only.\n",
        "\n",
        "from __future__ import absolute_import\n",
        "from __future__ import division\n",
        "from __future__ import print_function\n",
        "\n",
        "import time\n",
        "\n",
        "from graph_nets import blocks\n",
        "from graph_nets import utils_tf\n",
        "from graph_nets.demos import models\n",
        "from matplotlib import pyplot as plt\n",
        "import numpy as np\n",
        "import sonnet as snt\n",
        "import tensorflow as tf\n",
        "\n",
        "\n",
        "try:\n",
        "  import seaborn as sns\n",
        "except ImportError:\n",
        "  pass\n",
        "else:\n",
        "  sns.reset_orig()\n",
        "\n",
        "SEED = 1\n",
        "np.random.seed(SEED)\n",
        "tf.set_random_seed(SEED)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "cellView": "form",
        "colab": {},
        "colab_type": "code",
        "id": "toCQhJIM93en"
      },
      "outputs": [],
      "source": [
        "#@title Helper functions  { form-width: \"30%\" }\n",
        "\n",
        "# pylint: disable=redefined-outer-name\n",
        "\n",
        "def base_graph(n, d):\n",
        "  \"\"\"Define a basic mass-spring system graph structure.\n",
        "\n",
        "  These are n masses (1kg) connected by springs in a chain-like structure. The\n",
        "  first and last masses are fixed. The masses are horizontally aligned at the\n",
        "  start and are d meters apart; this is also the rest length for the springs\n",
        "  connecting them. Springs have spring constant 50 N/m and gravity is 10 N in\n",
        "  the negative y-direction.\n",
        "\n",
        "  Args:\n",
        "    n: number of masses\n",
        "    d: distance between masses (as well as springs' rest length)\n",
        "\n",
        "  Returns:\n",
        "    data_dict: dictionary with globals, nodes, edges, receivers and senders\n",
        "        to represent a structure like the one above.\n",
        "  \"\"\"\n",
        "  # Nodes\n",
        "  # Generate initial position and velocity for all masses.\n",
        "  # The left-most mass is at position (-d * n / 2, 0); the other masses (ordered\n",
        "  # left to right) are d meters to the right of their left neighbor, with\n",
        "  # y-coordinate 0. All masses have initial velocity 0 m/s.\n",
        "  nodes = np.zeros((n, 5), dtype=np.float32)\n",
        "  half_width = d * n / 2.0\n",
        "  nodes[:, 0] = np.linspace(\n",
        "      -half_width, half_width, num=n, endpoint=False, dtype=np.float32)\n",
        "  # indicate that the first and last masses are fixed\n",
        "  nodes[(0, -1), -1] = 1.\n",
        "\n",
        "  # Edges.\n",
        "  edges, senders, receivers = [], [], []\n",
        "  for i in range(n - 1):\n",
        "    left_node = i\n",
        "    right_node = i + 1\n",
        "    # The 'if' statements prevent incoming edges to fixed ends of the string.\n",
        "    if right_node \u003c n - 1:\n",
        "      # Left incoming edge.\n",
        "      edges.append([50., d])\n",
        "      senders.append(left_node)\n",
        "      receivers.append(right_node)\n",
        "    if left_node \u003e 0:\n",
        "      # Right incoming edge.\n",
        "      edges.append([50., d])\n",
        "      senders.append(right_node)\n",
        "      receivers.append(left_node)\n",
        "\n",
        "  return {\n",
        "      \"globals\": [0., -10.],\n",
        "      \"nodes\": nodes,\n",
        "      \"edges\": edges,\n",
        "      \"receivers\": receivers,\n",
        "      \"senders\": senders\n",
        "  }\n",
        "\n",
        "\n",
        "def hookes_law(receiver_nodes, sender_nodes, k, x_rest):\n",
        "  \"\"\"Applies Hooke's law to springs connecting some nodes.\n",
        "\n",
        "  Args:\n",
        "    receiver_nodes: Ex5 tf.Tensor of [x, y, v_x, v_y, is_fixed] features for the\n",
        "      receiver node of each edge.\n",
        "    sender_nodes: Ex5 tf.Tensor of [x, y, v_x, v_y, is_fixed] features for the\n",
        "      sender node of each edge.\n",
        "    k: Spring constant for each edge.\n",
        "    x_rest: Rest length of each edge.\n",
        "\n",
        "  Returns:\n",
        "    Ex2 tf.Tensor of the force [f_x, f_y] acting on the receiver node of each edge.\n",
        "  \"\"\"\n",
        "  diff = receiver_nodes[..., 0:2] - sender_nodes[..., 0:2]\n",
        "  x = tf.norm(diff, axis=-1, keepdims=True)\n",
        "  force_magnitude = -1 * tf.multiply(k, (x - x_rest) / x)\n",
        "  force = force_magnitude * diff\n",
        "  return force\n",
        "\n",
        "\n",
        "def euler_integration(nodes, force_per_node, step_size):\n",
        "  \"\"\"Applies one step of Euler integration.\n",
        "\n",
        "  Args:\n",
        "    nodes: Nx5 tf.Tensor of [x, y, v_x, v_y, is_fixed] features for each node.\n",
        "    force_per_node: Nx2 tf.Tensor of the force [f_x, f_y] acting on each node.\n",
        "    step_size: Scalar.\n",
        "\n",
        "  Returns:\n",
        "    An Nx2 tf.Tensor of the updated velocities [v_x, v_y]. Masses are 1kg, so\n",
        "        force equals acceleration; positions are updated separately.\n",
        "  \"\"\"\n",
        "  is_fixed = nodes[..., 4:5]\n",
        "  # set forces to zero for fixed nodes\n",
        "  force_per_node *= 1 - is_fixed\n",
        "  new_vel = nodes[..., 2:4] + force_per_node * step_size\n",
        "  return new_vel\n",
        "\n",
        "\n",
        "class SpringMassSimulator(snt.AbstractModule):\n",
        "  \"\"\"Implements a basic Physics Simulator using the blocks library.\"\"\"\n",
        "\n",
        "  def __init__(self, step_size, name=\"SpringMassSimulator\"):\n",
        "    super(SpringMassSimulator, self).__init__(name=name)\n",
        "    self._step_size = step_size\n",
        "\n",
        "    with self._enter_variable_scope():\n",
        "      self._aggregator = blocks.ReceivedEdgesToNodesAggregator(\n",
        "          reducer=tf.unsorted_segment_sum)\n",
        "\n",
        "  def _build(self, graph):\n",
        "    \"\"\"Builds a SpringMassSimulator.\n",
        "\n",
        "    Args:\n",
        "      graph: A graphs.GraphsTuple having, for some integers N, E, G:\n",
        "          - edges: Ex2 tf.Tensor of [spring_constant, rest_length] for each\n",
        "            edge.\n",
        "          - nodes: Nx5 tf.Tensor of [x, y, v_x, v_y, is_fixed] features for each\n",
        "            node.\n",
        "          - globals: Gx2 tf.Tensor containing the gravitational constant.\n",
        "\n",
        "    Returns:\n",
        "      A graphs.GraphsTuple with the same structure as `graph`, but where:\n",
        "          - edges: Ex2 tf.Tensor of the force [f_x, f_y] acting on each edge.\n",
        "          - nodes: Nx2 tf.Tensor of the velocities [v_x, v_y] after one step of\n",
        "              Euler integration.\n",
        "    \"\"\"\n",
        "    receiver_nodes = blocks.broadcast_receiver_nodes_to_edges(graph)\n",
        "    sender_nodes = blocks.broadcast_sender_nodes_to_edges(graph)\n",
        "\n",
        "    spring_force_per_edge = hookes_law(receiver_nodes, sender_nodes,\n",
        "                                       graph.edges[..., 0:1],\n",
        "                                       graph.edges[..., 1:2])\n",
        "    graph = graph.replace(edges=spring_force_per_edge)\n",
        "\n",
        "    spring_force_per_node = self._aggregator(graph)\n",
        "    gravity = blocks.broadcast_globals_to_nodes(graph)\n",
        "    updated_velocities = euler_integration(\n",
        "        graph.nodes, spring_force_per_node + gravity, self._step_size)\n",
        "    graph = graph.replace(nodes=updated_velocities)\n",
        "    return graph\n",
        "\n",
        "\n",
        "def prediction_to_next_state(input_graph, predicted_graph, step_size):\n",
        "  # manually integrate velocities to compute new positions\n",
        "  new_pos = input_graph.nodes[..., :2] + predicted_graph.nodes * step_size\n",
        "  new_nodes = tf.concat(\n",
        "      [new_pos, predicted_graph.nodes, input_graph.nodes[..., 4:5]], axis=-1)\n",
        "  return input_graph.replace(nodes=new_nodes)\n",
        "\n",
        "\n",
        "def roll_out_physics(simulator, graph, steps, step_size):\n",
        "  \"\"\"Rolls out a simulator (or a learned model) for a number of steps.\n",
        "\n",
        "  Args:\n",
        "    simulator: A SpringMassSimulator, or some module or callable with the same\n",
        "      signature.\n",
        "    graph: A graphs.GraphsTuple having, for some integers N, E, G:\n",
        "        - edges: Ex2 tf.Tensor of [spring_constant, rest_length] for each edge.\n",
        "        - nodes: Nx5 tf.Tensor of [x, y, v_x, v_y, is_fixed] features for each\n",
        "          node.\n",
        "        - globals: Gx2 tf.Tensor containing the gravitational constant.\n",
        "    steps: An integer.\n",
        "    step_size: Scalar.\n",
        "\n",
        "  Returns:\n",
        "    A pair of:\n",
        "    - The graph, updated after `steps` steps of simulation;\n",
        "    - A `steps+1`xNx5 tf.Tensor of the node features at each step.\n",
        "  \"\"\"\n",
        "\n",
        "  def body(t, graph, nodes_per_step):\n",
        "    predicted_graph = simulator(graph)\n",
        "    if isinstance(predicted_graph, list):\n",
        "      predicted_graph = predicted_graph[-1]\n",
        "    graph = prediction_to_next_state(graph, predicted_graph, step_size)\n",
        "    return t + 1, graph, nodes_per_step.write(t, graph.nodes)\n",
        "\n",
        "  nodes_per_step = tf.TensorArray(\n",
        "      dtype=graph.nodes.dtype, size=steps + 1, element_shape=graph.nodes.shape)\n",
        "  nodes_per_step = nodes_per_step.write(0, graph.nodes)\n",
        "\n",
        "  _, g, nodes_per_step = tf.while_loop(\n",
        "      lambda t, *unused_args: t \u003c= steps,\n",
        "      body,\n",
        "      loop_vars=[1, graph, nodes_per_step])\n",
        "  return g, nodes_per_step.stack()\n",
        "\n",
        "\n",
        "def apply_noise(graph, node_noise_level, edge_noise_level, global_noise_level):\n",
        "  \"\"\"Applies uniformly-distributed noise to a graph of a physical system.\n",
        "\n",
        "  Noise is applied to:\n",
        "  - the x and y coordinates (independently) of the nodes;\n",
        "  - the spring constants of the edges;\n",
        "  - the y coordinate of the global gravitational constant.\n",
        "\n",
        "  Args:\n",
        "    graph: a graphs.GraphsTuple having, for some integers N, E, G:\n",
        "        - nodes: Nx5 Tensor of [x, y, _, _, _] for each node.\n",
        "        - edges: Ex2 Tensor of [spring_constant, _] for each edge.\n",
        "        - globals: Gx2 tf.Tensor containing the gravitational constant.\n",
        "    node_noise_level: Maximum distance to perturb nodes' x and y coordinates.\n",
        "    edge_noise_level: Maximum amount to perturb edge spring constants.\n",
        "    global_noise_level: Maximum amount to perturb the Y component of gravity.\n",
        "\n",
        "  Returns:\n",
        "    The input graph, but with noise applied.\n",
        "  \"\"\"\n",
        "  node_position_noise = tf.random_uniform(\n",
        "      [graph.nodes.shape[0].value, 2],\n",
        "      minval=-node_noise_level,\n",
        "      maxval=node_noise_level)\n",
        "  edge_spring_constant_noise = tf.random_uniform(\n",
        "      [graph.edges.shape[0].value, 1],\n",
        "      minval=-edge_noise_level,\n",
        "      maxval=edge_noise_level)\n",
        "  global_gravity_y_noise = tf.random_uniform(\n",
        "      [graph.globals.shape[0].value, 1],\n",
        "      minval=-global_noise_level,\n",
        "      maxval=global_noise_level)\n",
        "\n",
        "  return graph.replace(\n",
        "      nodes=tf.concat(\n",
        "          [graph.nodes[..., :2] + node_position_noise, graph.nodes[..., 2:]],\n",
        "          axis=-1),\n",
        "      edges=tf.concat(\n",
        "          [\n",
        "              graph.edges[..., :1] + edge_spring_constant_noise,\n",
        "              graph.edges[..., 1:]\n",
        "          ],\n",
        "          axis=-1),\n",
        "      globals=tf.concat(\n",
        "          [\n",
        "              graph.globals[..., :1],\n",
        "              graph.globals[..., 1:] + global_gravity_y_noise\n",
        "          ],\n",
        "          axis=-1))\n",
        "\n",
        "\n",
        "def set_rest_lengths(graph):\n",
        "  \"\"\"Computes and sets rest lengths for the springs in a physical system.\n",
        "\n",
        "  The rest length is taken to be the distance between each edge's nodes.\n",
        "\n",
        "  Args:\n",
        "    graph: a graphs.GraphsTuple having, for some integers N, E:\n",
        "        - nodes: Nx5 Tensor of [x, y, _, _, _] for each node.\n",
        "        - edges: Ex2 Tensor of [spring_constant, _] for each edge.\n",
        "\n",
        "  Returns:\n",
        "    The input graph, but with [spring_constant, rest_length] for each edge.\n",
        "  \"\"\"\n",
        "  receiver_nodes = blocks.broadcast_receiver_nodes_to_edges(graph)\n",
        "  sender_nodes = blocks.broadcast_sender_nodes_to_edges(graph)\n",
        "  rest_length = tf.norm(\n",
        "      receiver_nodes[..., :2] - sender_nodes[..., :2], axis=-1, keepdims=True)\n",
        "  return graph.replace(\n",
        "      edges=tf.concat([graph.edges[..., :1], rest_length], axis=-1))\n",
        "\n",
        "\n",
        "def generate_trajectory(simulator, graph, steps, step_size, node_noise_level,\n",
        "                        edge_noise_level, global_noise_level):\n",
        "  \"\"\"Applies noise and then simulates a physical system for a number of steps.\n",
        "\n",
        "  Args:\n",
        "    simulator: A SpringMassSimulator, or some module or callable with the same\n",
        "      signature.\n",
        "    graph: a graphs.GraphsTuple having, for some integers N, E, G:\n",
        "        - nodes: Nx5 Tensor of [x, y, v_x, v_y, is_fixed] for each node.\n",
        "        - edges: Ex2 Tensor of [spring_constant, _] for each edge.\n",
        "        - globals: Gx2 tf.Tensor containing the gravitational constant.\n",
        "    steps: Integer; the length of trajectory to generate.\n",
        "    step_size: Scalar.\n",
        "    node_noise_level: Maximum distance to perturb nodes' x and y coordinates.\n",
        "    edge_noise_level: Maximum amount to perturb edge spring constants.\n",
        "    global_noise_level: Maximum amount to perturb the Y component of gravity.\n",
        "\n",
        "  Returns:\n",
        "    A pair of:\n",
        "    - The input graph, but with rest lengths computed and noise applied.\n",
        "    - A `steps+1`xNx5 tf.Tensor of the node features at each step.\n",
        "  \"\"\"\n",
        "  graph = apply_noise(graph, node_noise_level, edge_noise_level,\n",
        "                      global_noise_level)\n",
        "  graph = set_rest_lengths(graph)\n",
        "  _, n = roll_out_physics(simulator, graph, steps, step_size)\n",
        "  return graph, n\n",
        "\n",
        "\n",
        "def create_loss_ops(target_op, output_ops):\n",
        "  \"\"\"Create supervised loss operations from targets and outputs.\n",
        "\n",
        "  Args:\n",
        "    target_op: Nx5 tf.Tensor of target node features; only the velocities\n",
        "      [v_x, v_y] (columns 2:4) are compared against the outputs.\n",
        "    output_ops: The list of output graphs from the model.\n",
        "\n",
        "  Returns:\n",
        "    A list of loss values (tf.Tensor), one per output op.\n",
        "  \"\"\"\n",
        "  loss_ops = [\n",
        "      tf.reduce_mean(\n",
        "          tf.reduce_sum((output_op.nodes - target_op[..., 2:4])**2, axis=-1))\n",
        "      for output_op in output_ops\n",
        "  ]\n",
        "  return loss_ops\n",
        "\n",
        "\n",
        "def make_all_runnable_in_session(*args):\n",
        "  \"\"\"Apply make_runnable_in_session to an iterable of graphs.\"\"\"\n",
        "  return [utils_tf.make_runnable_in_session(a) for a in args]\n",
        "\n",
        "\n",
        "# pylint: enable=redefined-outer-name"
      ]
    },
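    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To check the physics outside TensorFlow, the following NumPy sketch performs one step of the same update as `SpringMassSimulator` followed by `prediction_to_next_state`. It applies Hooke's law on every edge, sums the incoming edge forces at each receiver, adds gravity, and zeroes the force on fixed masses. It then updates the velocity (masses are 1kg, so force equals acceleration) and finally updates the position using the new velocity, a scheme sometimes called semi-implicit Euler. The function name `numpy_spring_step` and the three-mass example are illustrative assumptions, not part of Graph Nets.\n",
        "\n",
        "```python\n",
        "import numpy as np\n",
        "\n",
        "\n",
        "def numpy_spring_step(nodes, edges, senders, receivers, gravity, step_size):\n",
        "  # nodes: Nx5 [x, y, v_x, v_y, is_fixed]; edges: Ex2 [k, rest_length].\n",
        "  diff = nodes[receivers, :2] - nodes[senders, :2]\n",
        "  length = np.linalg.norm(diff, axis=-1, keepdims=True)\n",
        "  # Hooke's law: force on each edge's receiver node.\n",
        "  force_per_edge = -edges[:, :1] * (length - edges[:, 1:2]) / length * diff\n",
        "  # Sum incoming edge forces per receiver (cf. tf.unsorted_segment_sum).\n",
        "  force_per_node = np.zeros((nodes.shape[0], 2))\n",
        "  np.add.at(force_per_node, receivers, force_per_edge)\n",
        "  force_per_node += gravity\n",
        "  force_per_node *= 1 - nodes[:, 4:5]  # fixed masses do not move\n",
        "  new_vel = nodes[:, 2:4] + force_per_node * step_size\n",
        "  new_pos = nodes[:, :2] + new_vel * step_size  # uses the updated velocity\n",
        "  return np.concatenate([new_pos, new_vel, nodes[:, 4:5]], axis=-1)\n",
        "\n",
        "\n",
        "# Three masses 1m apart; only the middle one is free, and both springs are at\n",
        "# rest length, so gravity alone acts on it during the first step.\n",
        "nodes = np.array([[-1., 0., 0., 0., 1.],\n",
        "                  [0., 0., 0., 0., 0.],\n",
        "                  [1., 0., 0., 0., 1.]])\n",
        "edges = np.array([[50., 1.], [50., 1.]])\n",
        "senders = np.array([0, 2])\n",
        "receivers = np.array([1, 1])\n",
        "print(numpy_spring_step(nodes, edges, senders, receivers,\n",
        "                        np.array([0., -10.]), step_size=0.1))\n",
        "```\n",
        "\n",
        "After this step the free middle mass is at y = -0.1 m with v_y = -1 m/s, while the two fixed end masses stay where they are."
      ]
    },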
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "cellView": "form",
        "colab": {},
        "colab_type": "code",
        "id": "pf7u0zuN_ktd"
      },
      "outputs": [],
      "source": [
        "#@title Set up model training and evaluation  { form-width: \"30%\" }\n",
        "\n",
        "# The model we explore includes three components:\n",
        "# - An \"Encoder\" graph net, which independently encodes the edge, node, and\n",
        "#   global attributes (does not compute relations etc.).\n",
        "# - A \"Core\" graph net, which performs N rounds of processing (message-passing)\n",
        "#   steps. The input to the Core is the concatenation of the Encoder's output\n",
        "#   and the previous output of the Core (labeled \"Hidden(t)\" below, where \"t\" is\n",
        "#   the processing step).\n",
        "# - A \"Decoder\" graph net, which independently decodes the edge, node, and\n",
        "#   global attributes (does not compute relations etc.), on each\n",
        "#   message-passing step.\n",
        "#\n",
        "#                     Hidden(t)   Hidden(t+1)\n",
        "#                        |            ^\n",
        "#           *---------*  |  *------*  |  *---------*\n",
        "#           |         |  |  |      |  |  |         |\n",
        "# Input ---\u003e| Encoder |  *-\u003e| Core |--*-\u003e| Decoder |---\u003e Output(t)\n",
        "#           |         |----\u003e|      |     |         |\n",
        "#           *---------*     *------*     *---------*\n",
        "#\n",
        "# The model is trained by supervised learning. Input mass-spring systems are\n",
        "# procedurally generated, where the nodes represent the positions, velocities,\n",
        "# and indicators of whether the mass is fixed in space or free to move, the\n",
        "# edges represent the spring constant and spring rest length, and the global\n",
        "# attribute represents the variable coefficient of gravitational acceleration.\n",
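        "# The targets have the same structure, with the nodes representing the masses'\n",
        "# next-step states; the model's outputs are only the next-step node velocities\n",
        "# (node_output_size=2), which are integrated to obtain the new positions.\n",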
        "#\n",
        "# The training loss is computed on the output of each processing step. The\n",
        "# reason for this is to encourage the model to try to solve the problem in as\n",
        "# few steps as possible. It also helps make the output of intermediate steps\n",
        "# more interpretable.\n",
        "#\n",
        "# There's no need for a separate evaluation dataset because the inputs are\n",
        "# never repeated, so the training loss is the measure of performance on graphs\n",
        "# from the input distribution.\n",
        "#\n",
        "# We also evaluate how well the models generalize to systems which are one mass\n",
        "# larger, and smaller, than those from the training distribution. The loss is\n",
        "# computed as the mean over a 50-step rollout, where each step's input is the\n",
        "# previous step's output.\n",
        "#\n",
        "# Variables with the suffix _tr are training parameters, and variables with the\n",
        "# suffix _ge are test/generalization parameters.\n",
        "#\n",
        "# After around 10000-20000 training iterations the model reaches good\n",
        "# performance on mass-spring systems with 5-8 masses.\n",
        "\n",
        "tf.reset_default_graph()\n",
        "\n",
        "rand = np.random.RandomState(SEED)\n",
        "\n",
        "# Model parameters.\n",
        "num_processing_steps_tr = 1\n",
        "num_processing_steps_ge = 1\n",
        "\n",
        "# Data / training parameters.\n",
        "num_training_iterations = 100000\n",
        "batch_size_tr = 256\n",
        "batch_size_ge = 100\n",
        "num_time_steps = 50\n",
        "step_size = 0.1\n",
        "num_masses_min_max_tr = (5, 9)\n",
        "dist_between_masses_min_max_tr = (0.2, 1.0)\n",
        "\n",
        "# Create the model.\n",
        "model = models.EncodeProcessDecode(node_output_size=2)\n",
        "\n",
        "# Data.\n",
        "# Base graphs for training.\n",
        "num_masses_tr = rand.randint(*num_masses_min_max_tr, size=batch_size_tr)\n",
        "dist_between_masses_tr = rand.uniform(\n",
        "    *dist_between_masses_min_max_tr, size=batch_size_tr)\n",
        "static_graph_tr = [\n",
        "    base_graph(n, d) for n, d in zip(num_masses_tr, dist_between_masses_tr)\n",
        "]\n",
        "base_graph_tr = utils_tf.data_dicts_to_graphs_tuple(static_graph_tr)\n",
        "# Base graphs for testing.\n",
        "# 4 masses 0.5m apart in a chain-like structure.\n",
        "base_graph_4_ge = utils_tf.data_dicts_to_graphs_tuple(\n",
        "    [base_graph(4, 0.5)] * batch_size_ge)\n",
        "# 9 masses 0.5m apart in a chain-like structure.\n",
        "base_graph_9_ge = utils_tf.data_dicts_to_graphs_tuple(\n",
        "    [base_graph(9, 0.5)] * batch_size_ge)\n",
        "# True physics simulator for data generation.\n",
        "simulator = SpringMassSimulator(step_size=step_size)\n",
        "# Training.\n",
        "# Generate a training trajectory by adding noise to initial\n",
        "# position, spring constants and gravity\n",
        "initial_conditions_tr, true_trajectory_tr = generate_trajectory(\n",
        "    simulator,\n",
        "    base_graph_tr,\n",
        "    num_time_steps,\n",
        "    step_size,\n",
        "    node_noise_level=0.04,\n",
        "    edge_noise_level=5.0,\n",
        "    global_noise_level=1.0)\n",
        "# Random start step.\n",
        "t = tf.random_uniform([], minval=0, maxval=num_time_steps - 1, dtype=tf.int32)\n",
        "input_graph_tr = initial_conditions_tr.replace(nodes=true_trajectory_tr[t])\n",
        "target_nodes_tr = true_trajectory_tr[t + 1]\n",
        "output_ops_tr = model(input_graph_tr, num_processing_steps_tr)\n",
        "# Test data: 4-mass string.\n",
        "initial_conditions_4_ge, true_trajectory_4_ge = generate_trajectory(\n",
        "    simulator,\n",
        "    base_graph_4_ge,\n",
        "    num_time_steps,\n",
        "    step_size,\n",
        "    node_noise_level=0.04,\n",
        "    edge_noise_level=5.0,\n",
        "    global_noise_level=1.0)\n",
        "_, true_nodes_rollout_4_ge = roll_out_physics(\n",
        "    simulator, initial_conditions_4_ge, num_time_steps, step_size)\n",
        "_, predicted_nodes_rollout_4_ge = roll_out_physics(\n",
        "    lambda x: model(x, num_processing_steps_ge), initial_conditions_4_ge,\n",
        "    num_time_steps, step_size)\n",
        "# Test data: 9-mass string.\n",
        "initial_conditions_9_ge, true_trajectory_9_ge = generate_trajectory(\n",
        "    simulator,\n",
        "    base_graph_9_ge,\n",
        "    num_time_steps,\n",
        "    step_size,\n",
        "    node_noise_level=0.04,\n",
        "    edge_noise_level=5.0,\n",
        "    global_noise_level=1.0)\n",
        "_, true_nodes_rollout_9_ge = roll_out_physics(\n",
        "    simulator, initial_conditions_9_ge, num_time_steps, step_size)\n",
        "_, predicted_nodes_rollout_9_ge = roll_out_physics(\n",
        "    lambda x: model(x, num_processing_steps_ge), initial_conditions_9_ge,\n",
        "    num_time_steps, step_size)\n",
        "\n",
        "# Training loss.\n",
        "loss_ops_tr = create_loss_ops(target_nodes_tr, output_ops_tr)\n",
        "# Training loss across processing steps.\n",
        "loss_op_tr = sum(loss_ops_tr) / num_processing_steps_tr\n",
        "# Test/generalization loss: 4-mass string.\n",
        "loss_op_4_ge = tf.reduce_mean(\n",
        "    tf.reduce_sum(\n",
        "        (predicted_nodes_rollout_4_ge[..., 2:4] -\n",
        "         true_nodes_rollout_4_ge[..., 2:4])**2,\n",
        "        axis=-1))\n",
        "# Test/generalization loss: 9-mass string.\n",
        "loss_op_9_ge = tf.reduce_mean(\n",
        "    tf.reduce_sum(\n",
        "        (predicted_nodes_rollout_9_ge[..., 2:4] -\n",
        "         true_nodes_rollout_9_ge[..., 2:4])**2,\n",
        "        axis=-1))\n",
        "\n",
        "# Optimizer.\n",
        "learning_rate = 1e-3\n",
        "optimizer = tf.train.AdamOptimizer(learning_rate)\n",
        "step_op = optimizer.minimize(loss_op_tr)\n",
        "\n",
        "input_graph_tr = make_all_runnable_in_session(input_graph_tr)\n",
        "initial_conditions_4_ge = make_all_runnable_in_session(initial_conditions_4_ge)\n",
        "initial_conditions_9_ge = make_all_runnable_in_session(initial_conditions_9_ge)"
      ]
    },
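    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Each of the two generalization losses defined above sums the squared velocity error over the (v_x, v_y) components and then averages over every rollout step and every mass in the batch. A minimal NumPy equivalent, assuming rollouts stored as `(steps + 1, N, 5)` arrays like those returned by `roll_out_physics` (the helper name `rollout_velocity_loss` is hypothetical):\n",
        "\n",
        "```python\n",
        "import numpy as np\n",
        "\n",
        "\n",
        "def rollout_velocity_loss(predicted_nodes, true_nodes):\n",
        "  # Both arrays have shape (steps + 1, N, 5): [x, y, v_x, v_y, is_fixed].\n",
        "  sq_err = (predicted_nodes[..., 2:4] - true_nodes[..., 2:4]) ** 2\n",
        "  # Sum over (v_x, v_y), then average over time steps and nodes.\n",
        "  return np.mean(np.sum(sq_err, axis=-1))\n",
        "\n",
        "\n",
        "true_nodes = np.zeros((3, 4, 5))\n",
        "predicted_nodes = true_nodes.copy()\n",
        "predicted_nodes[1:, :, 2] = 1.0  # v_x off by 1 m/s after the initial step\n",
        "print(rollout_velocity_loss(predicted_nodes, true_nodes))\n",
        "```\n",
        "\n",
        "Because `roll_out_physics` writes the shared initial state at index 0 of both rollouts, that step always contributes zero error. In this toy example it lowers the mean from 1.0 to 2/3."
      ]
    },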
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "cellView": "form",
        "colab": {},
        "colab_type": "code",
        "id": "TpABTYk0Ap-V"
      },
      "outputs": [],
      "source": [
        "#@title Reset session  { form-width: \"30%\" }\n",
        "\n",
        "# This cell resets the TensorFlow session, but keeps the same computational\n",
        "# graph.\n",
        "\n",
        "try:\n",
        "  sess.close()\n",
        "except NameError:\n",
        "  pass\n",
        "sess = tf.Session()\n",
        "sess.run(tf.global_variables_initializer())\n",
        "\n",
        "last_iteration = 0\n",
        "logged_iterations = []\n",
        "losses_tr = []\n",
        "losses_4_ge = []\n",
        "losses_9_ge = []"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "cellView": "form",
        "colab": {},
        "colab_type": "code",
        "id": "4PCPvXiHA7q7"
      },
      "outputs": [],
      "source": [
        "#@title Run training  { form-width: \"30%\" }\n",
        "\n",
        "# You can interrupt this cell's training loop at any time, and visualize the\n",
        "# intermediate results by running the next cell (below). You can then resume\n",
        "# training by simply executing this cell again.\n",
        "\n",
        "# Number of seconds between logging and printing the current results.\n",
        "log_every_seconds = 20\n",
        "\n",
        "print(\"# (iteration number), T (elapsed seconds), \"\n",
        "      \"Ltr (training 1-step loss), \"\n",
        "      \"Lge4 (test/generalization rollout loss for 4-mass strings), \"\n",
        "      \"Lge9 (test/generalization rollout loss for 9-mass strings)\")\n",
        "\n",
        "start_time = time.time()\n",
        "last_log_time = start_time\n",
        "for iteration in range(last_iteration, num_training_iterations):\n",
        "  last_iteration = iteration\n",
        "  train_values = sess.run({\n",
        "      \"step\": step_op,\n",
        "      \"loss\": loss_op_tr,\n",
        "      \"input_graph\": input_graph_tr,\n",
        "      \"target_nodes\": target_nodes_tr,\n",
        "      \"outputs\": output_ops_tr\n",
        "  })\n",
        "  the_time = time.time()\n",
        "  elapsed_since_last_log = the_time - last_log_time\n",
        "  if elapsed_since_last_log \u003e log_every_seconds:\n",
        "    last_log_time = the_time\n",
        "    test_values = sess.run({\n",
        "        \"loss_4\": loss_op_4_ge,\n",
        "        \"true_rollout_4\": true_nodes_rollout_4_ge,\n",
        "        \"predicted_rollout_4\": predicted_nodes_rollout_4_ge,\n",
        "        \"loss_9\": loss_op_9_ge,\n",
        "        \"true_rollout_9\": true_nodes_rollout_9_ge,\n",
        "        \"predicted_rollout_9\": predicted_nodes_rollout_9_ge\n",
        "    })\n",
        "    elapsed = time.time() - start_time\n",
        "    losses_tr.append(train_values[\"loss\"])\n",
        "    losses_4_ge.append(test_values[\"loss_4\"])\n",
        "    losses_9_ge.append(test_values[\"loss_9\"])\n",
        "    logged_iterations.append(iteration)\n",
        "    print(\"# {:05d}, T {:.1f}, Ltr {:.4f}, Lge4 {:.4f}, Lge9 {:.4f}\".format(\n",
        "        iteration, elapsed, train_values[\"loss\"], test_values[\"loss_4\"],\n",
        "        test_values[\"loss_9\"]))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "cellView": "form",
        "colab": {},
        "colab_type": "code",
        "id": "j1FgiIm-pmRq"
      },
      "outputs": [],
      "source": [
        "#@title Visualize loss curves  { form-width: \"30%\" }\n",
        "\n",
        "# This cell visualizes the results of training. You can visualize the\n",
        "# intermediate results by interrupting execution of the cell above, and running\n",
        "# this cell. You can then resume training by simply executing the above cell\n",
        "# again.\n",
        "\n",
        "def get_node_trajectories(rollout_array, batch_size):  # pylint: disable=redefined-outer-name\n",
        "  return np.split(rollout_array[..., :2], batch_size, axis=1)\n",
        "\n",
        "\n",
        "fig = plt.figure(1, figsize=(18, 3))\n",
        "fig.clf()\n",
        "x = np.array(logged_iterations)\n",
        "# Next-step Loss.\n",
        "y = losses_tr\n",
        "ax = fig.add_subplot(1, 3, 1)\n",
        "ax.plot(x, y, \"k\")\n",
        "ax.set_title(\"Next step loss\")\n",
        "# Rollout 5 loss.\n",
        "y = losses_4_ge\n",
        "ax = fig.add_subplot(1, 3, 2)\n",
        "ax.plot(x, y, \"k\")\n",
        "ax.set_title(\"Rollout loss: 5-mass string\")\n",
        "# Rollout 9 loss.\n",
        "y = losses_9_ge\n",
        "ax = fig.add_subplot(1, 3, 3)\n",
        "ax.plot(x, y, \"k\")\n",
        "ax.set_title(\"Rollout loss: 9-mass string\")\n",
        "\n",
        "# Visualize trajectories.\n",
        "true_rollouts_4 = get_node_trajectories(test_values[\"true_rollout_4\"],\n",
        "                                        batch_size_ge)\n",
        "predicted_rollouts_4 = get_node_trajectories(test_values[\"predicted_rollout_4\"],\n",
        "                                             batch_size_ge)\n",
        "true_rollouts_9 = get_node_trajectories(test_values[\"true_rollout_9\"],\n",
        "                                        batch_size_ge)\n",
        "predicted_rollouts_9 = get_node_trajectories(test_values[\"predicted_rollout_9\"],\n",
        "                                             batch_size_ge)\n",
        "\n",
        "true_rollouts = true_rollouts_4\n",
        "predicted_rollouts = predicted_rollouts_4\n",
        "true_rollouts = true_rollouts_9\n",
        "predicted_rollouts = predicted_rollouts_9\n",
        "\n",
        "num_graphs = len(true_rollouts)\n",
        "num_time_steps = true_rollouts[0].shape[0]\n",
        "\n",
        "# Plot state sequences.\n",
        "max_graphs_to_plot = 1\n",
        "num_graphs_to_plot = min(num_graphs, max_graphs_to_plot)\n",
        "num_steps_to_plot = 24\n",
        "max_time_step = num_time_steps - 1\n",
        "step_indices = np.floor(np.linspace(0, max_time_step,\n",
        "                                    num_steps_to_plot)).astype(int).tolist()\n",
        "w = 6\n",
        "h = int(np.ceil(num_steps_to_plot / w))\n",
        "fig = plt.figure(101, figsize=(18, 8))\n",
        "fig.clf()\n",
        "for i, (true_rollout, predicted_rollout) in enumerate(\n",
        "    zip(true_rollouts, predicted_rollouts)):\n",
        "  xys = np.hstack([predicted_rollout, true_rollout]).reshape([-1, 2])\n",
        "  xs = xys[:, 0]\n",
        "  ys = xys[:, 1]\n",
        "  b = 0.05\n",
        "  xmin = xs.min() - b * xs.ptp()\n",
        "  xmax = xs.max() + b * xs.ptp()\n",
        "  ymin = ys.min() - b * ys.ptp()\n",
        "  ymax = ys.max() + b * ys.ptp()\n",
        "  if i \u003e= num_graphs_to_plot:\n",
        "    break\n",
        "  for j, step_index in enumerate(step_indices):\n",
        "    iax = i * w + j + 1\n",
        "    ax = fig.add_subplot(h, w, iax)\n",
        "    ax.plot(\n",
        "        true_rollout[step_index, :, 0],\n",
        "        true_rollout[step_index, :, 1],\n",
        "        \"k\",\n",
        "        label=\"True\")\n",
        "    ax.plot(\n",
        "        predicted_rollout[step_index, :, 0],\n",
        "        predicted_rollout[step_index, :, 1],\n",
        "        \"r\",\n",
        "        label=\"Predicted\")\n",
        "    ax.set_title(\"Example {:02d}: frame {:03d}\".format(i, step_index))\n",
        "    ax.set_xlim(xmin, xmax)\n",
        "    ax.set_ylim(ymin, ymax)\n",
        "    ax.set_xticks([])\n",
        "    ax.set_yticks([])\n",
        "    if j == 0:\n",
        "      ax.legend(loc=3)\n",
        "\n",
        "# Plot x and y trajectories over time.\n",
        "max_graphs_to_plot = 3\n",
        "num_graphs_to_plot = min(len(true_rollouts), max_graphs_to_plot)\n",
        "w = 2\n",
        "h = num_graphs_to_plot\n",
        "fig = plt.figure(102, figsize=(18, 12))\n",
        "fig.clf()\n",
        "for i, (true_rollout, predicted_rollout) in enumerate(\n",
        "    zip(true_rollouts, predicted_rollouts)):\n",
        "  if i \u003e= num_graphs_to_plot:\n",
        "    break\n",
        "  t = np.arange(num_time_steps)\n",
        "  for j in range(2):\n",
        "    coord_string = \"x\" if j == 0 else \"y\"\n",
        "    iax = i * 2 + j + 1\n",
        "    ax = fig.add_subplot(h, w, iax)\n",
        "    ax.plot(t, true_rollout[..., j], \"k\", label=\"True\")\n",
        "    ax.plot(t, predicted_rollout[..., j], \"r\", label=\"Predicted\")\n",
        "    ax.set_xlabel(\"Time\")\n",
        "    ax.set_ylabel(\"{} coordinate\".format(coord_string))\n",
        "    ax.set_title(\"Example {:02d}: Predicted vs actual coords over time\".format(\n",
        "        i))\n",
        "    ax.set_frame_on(False)\n",
        "    if i == 0 and j == 1:\n",
        "      handles, labels = ax.get_legend_handles_labels()\n",
        "      unique_labels = []\n",
        "      unique_handles = []\n",
        "      for i, (handle, label) in enumerate(zip(handles, labels)):  # pylint: disable=redefined-outer-name\n",
        "        if label not in unique_labels:\n",
        "          unique_labels.append(label)\n",
        "          unique_handles.append(handle)\n",
        "      ax.legend(unique_handles, unique_labels, loc=3)"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "collapsed_sections": [],
      "last_runtime": {
        "build_target": "//learning/deepmind/dm_python:dm_notebook3",
        "kind": "private"
      },
      "name": "physics.ipynb",
      "provenance": [],
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
